DeepEM: A Deep Neural Network for DEM Inversion

by Paul Wright$^{1}$, Mark Cheung, Rajat Thomas, Richard Galvez, Alexandre Szenicer, Meng Jin, Andres Munoz-Jaramillo, and David Fouhey

$^{1}$ University of Glasgow; email: paul@pauljwright.co.uk

The intensity observed through optically-thin SDO/AIA filters (94 Å, 131 Å, 171 Å, 193 Å, 211 Å, 335 Å) can be related to the temperature distribution of the solar corona (the differential emission measure; DEM) as

\begin{equation} g_{i} = \int_{T} K_{i}(T) \xi(T) dT \, . \end{equation}

In this equation, $g_{i}$ is the DN s$^{-1}$ px$^{-1}$ value in the $i$th SDO/AIA channel, $K_{i}(T)$ is the temperature response function of that channel, and the DEM, $\xi(T)$, is in units of cm$^{-5}$ K$^{-1}$. The matrix formulation of this integral equation can be represented in the form $\vec{g} = {\bf K}\vec{\xi}$; however, this is an ill-posed inverse problem, and any attempt to directly recover $\vec{\xi}$ leads to significant noise amplification.
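The ill-posedness can be seen directly in a toy version of the forward problem. The sketch below uses a set of smooth, overlapping Gaussian kernels purely for illustration (these are not the real AIA response functions): with 6 channels and 18 temperature bins, very different DEMs can produce essentially identical observed intensities, so the data alone cannot pin down the solution.

```python
import numpy as np

# Illustrative sketch only (not the real AIA response): 6 broad, overlapping
# temperature-response kernels sampled on an 18-bin log10 T grid.
T = np.linspace(5.5, 7.2, 18)                            # log10 T grid
centers = np.linspace(5.8, 7.0, 6)                       # 6 hypothetical channels
K = np.exp(-((T[None, :] - centers[:, None]) / 0.5)**2)  # 6x18 kernel matrix

xi_a = np.exp(-((T - 6.3) / 0.2)**2)     # a smooth "DEM"
null = np.linalg.svd(K)[2][-1]           # a direction in the null space of K
xi_b = xi_a + 0.5 * null                 # a visibly different "DEM"

g_a, g_b = K @ xi_a, K @ xi_b
# The two DEMs differ substantially, yet the forward-modelled intensities
# are numerically indistinguishable:
assert np.abs(xi_a - xi_b).max() > 0.1
assert np.allclose(g_a, g_b, atol=1e-10)
```

Any noise in $\vec{g}$ is therefore free to move the recovered solution along such null-space directions, which is why direct inversion fails and regularised or learned approaches are needed.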

There are numerous methods to tackle mathematical problems of this kind, and an increasing number of methods in the literature for recovering the differential emission measure from SDO/AIA observations, including techniques based on Tikhonov regularisation (Hannah & Kontar 2012; https://doi.org/10.1051/0004-6361/201117576) and on the concept of sparsity (Cheung et al. 2015; https://doi.org/10.1088/0004-637X/807/2/143). In the following notebook, we will demonstrate how a simple 1x1 2D convolutional neural network allows for a significant improvement in computational speed for DEM inversion, with similar fidelity to the method used for training (Basis Pursuit). Additionally, this method, DeepEM, provides solutions with emission measure > 0 in every temperature bin.



In this chapter we will introduce a deep learning approach for DEM inversion. For this notebook, DeepEM is trained on one set of SDO/AIA observations (six optically thin channels; $6 \times N \times N$) and DEM solutions (18 temperature bins from log$_{10}$T = 5.5 to 7.2, $18 \times N \times N$; Cheung et al. 2015) at a resolution of $512 \times 512$ ($N = 512$), using a $1 \times 1$ 2D convolutional neural network with two hidden layers.

The DeepEM method presented here takes every DEM solution with no regard to the quality or existence of the solution. As will be demonstrated, when this method is trained on a single set of images and DEM solutions, the DeepEM solutions have a similar fidelity to Sparse Inversion (with a significantly increased computation speed); additionally, the DeepEM solutions are positive at every pixel and show reduced noise.

In [1]:
#This notebook has been written in PyTorch
import os
import json
import time
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torch.utils.data import DataLoader

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
In [2]:
#cudaize moves a tensor or model to the GPU, if one is available, for training and testing
def cudaize(obj):
    return obj.cuda() if torch.cuda.is_available() else obj
In [3]:
def em_scale(y):
    return np.sqrt(y/1e25)

def em_unscale(y):
    return 1e25*(y*y)

def img_scale(x):
    x2 = x.copy()        # copy so the input array is not modified in place
    x2[x2 <= 0.0] = 0.0  # clip non-physical (negative) DN values
    return np.sqrt(x2)

def img_unscale(x):
    return x*x 
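The square-root scaling compresses the large dynamic range of the emission measure so the network trains on targets of order unity. A quick round-trip check of the pair defined above (restated here so the snippet is self-contained; the values are illustrative):

```python
import numpy as np

# The scaling pair from the cell above: EM values (of order 1e25 cm^-5 K^-1
# and up) are divided by 1e25 and square-rooted before training.
def em_scale(y):
    return np.sqrt(y / 1e25)

def em_unscale(y):
    return 1e25 * (y * y)

em = np.array([1e25, 4e26, 9e27])
assert np.allclose(em_unscale(em_scale(em)), em)  # exact round trip

# img_scale additionally clips negative DN values to zero before the
# square root, so its round trip is lossy only for non-physical pixels.
```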

Step 1: Obtain Data and Sparse Inversion Solutions for Training

We first load the SDO/AIA images and Basis Pursuit DEM maps.

N.B. While this simplified version of DeepEM has been trained on DEM maps from Basis Pursuit (Cheung et al. 2015), we actively encourage the readers to try their favourite method for DEM inversion!

In [4]:
aia_files = ['AIA_DEM_2011-01-27','AIA_DEM_2011-02-22','AIA_DEM_2011-03-20']
em_cube_files = aia_files

for k, (afile, emfile) in enumerate(zip(aia_files, em_cube_files)):
    afile_name = os.path.join('./DeepEM_Data/', afile + '.aia.npy')
    emfile_name = os.path.join('./DeepEM_Data/', emfile + '.emcube.npy')
    if k == 0:
        # Load the first cube to determine the array dimensions, then
        # allocate arrays to hold every observation/DEM solution.
        X = np.load(afile_name)
        y = np.load(emfile_name)
 
        X = np.zeros((len(aia_files), X.shape[0], X.shape[1], X.shape[2]))
        y = np.zeros((len(em_cube_files), y.shape[0], y.shape[1], y.shape[2]))
        
        nlgT = y.shape[1]                    # number of temperature bins (18)
        lgtaxis = np.arange(nlgT)*0.1 + 5.5  # log10 T axis: 5.5 to 7.2
        
    X[k] = np.load(afile_name)
    y[k] = np.load(emfile_name) 
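As a sanity check on the temperature axis built in the loading loop, the 18 bins of width 0.1 dex starting at log$_{10}$T = 5.5 end at log$_{10}$T = 7.2:

```python
import numpy as np

# 18 bins of width 0.1 dex starting at log10 T = 5.5, as in the loop above
lgtaxis = np.arange(18) * 0.1 + 5.5
assert len(lgtaxis) == 18
assert lgtaxis[0] == 5.5
assert np.isclose(lgtaxis[-1], 7.2)
```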

Step 2: Define the Model

We define the model as a 2D Convolutional Neural Network (CNN) with a kernel size of 1x1. The model accepts a data cube of $6 \times N \times N$ (SDO/AIA data) and returns a data cube of $18 \times N \times N$ (DEM); when trained, it transforms the input (each pixel of the six SDO/AIA channels; $6 \times 1 \times 1$) to the output (the DEM at each pixel; $18 \times 1 \times 1$).

In [5]:
model = nn.Sequential(
    nn.Conv2d(6, 300, kernel_size=1),
    nn.LeakyReLU(), #Activation function
    nn.Conv2d(300, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 18, kernel_size=1))

model = cudaize(model)
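Because the kernel size is 1x1, each convolutional layer applies the same fully-connected map independently at every pixel; the network is effectively a per-pixel multilayer perceptron applied across the whole image at once. A minimal NumPy sketch of this equivalence (random weights, not the trained model):

```python
import numpy as np

rng = np.random.default_rng(0)
C_in, C_out, N = 6, 18, 4
x = rng.standard_normal((C_in, N, N))
W = rng.standard_normal((C_out, C_in))  # a 1x1 kernel is just a matrix
b = rng.standard_normal(C_out)

# "Convolution" with a 1x1 kernel: a matrix multiply over the channel axis.
y = np.einsum('oi,ihw->ohw', W, x) + b[:, None, None]

# Identical to applying the same linear map at each pixel separately:
y_px = np.array([[W @ x[:, i, j] + b for j in range(N)] for i in range(N)])
assert np.allclose(y, np.moveaxis(y_px, -1, 0))  # (N, N, C_out) -> (C_out, N, N)
```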

Step 3: Train the Model

For training our CNN we select one SDO/AIA data cube ($6\times512\times512$) and the corresponding Sparse Inversion DEM output ($18\times512\times512$). In the case presented here, we train the CNN on an image of the Sun obtained on 2011-01-27, validate on an image obtained one synodic rotation later (+26 days; 2011-02-22), and finally test on an image from another 26 days later (2011-03-20).

In [6]:
X = img_scale(X)
y = em_scale(y)

X_train = X[0:1] 
y_train = y[0:1] 

X_val = X[1:2] 
y_val = y[1:2] 

X_test = X[2:3] 
y_test = y[2:3]

Plotting SDO/AIA Observations ${\it vs.}$ Basis Pursuit DEM bins

For the test data set, Figure 1 shows the SDO/AIA images in 171 Å, 211 Å, and 94 Å (left to right) with the corresponding DEM maps for temperature bins near the peak sensitivity of each of these relatively isothermal channels (log$_{10}$T ~ 5.9, 6.3, and 7.0). It is clear from the DEM maps that a number of pixels are zero; these are primarily located off-disk, but a number of on-disk pixels show the same behaviour.

In [7]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(y_test[0,8,:,:],vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(y_test[0,4,:,:],vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(y_test[0,15,:,:],vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)

Figure 1: Left to Right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top), with the corresponding DEM bins (chosen near the peak sensitivity of each SDO/AIA channel) shown below. In the DEM maps (bottom) it is clear that some pixels have solutions of DEM = 0, seen as dark regions/clusters of pixels both on and off the disk.


To implement training and testing of our model, we first define a DEMdata class, and define functions for training and validation/test: train_model, and valtest_model.

N.B. It is not necessary to train the model; if required, the trained model can be loaded to the CPU as follows:

model = nn.Sequential(
    nn.Conv2d(6, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 18, kernel_size=1))
model = cudaize(model)

dem_model_file = 'DeepEM_CNN_HelioML.pth'
model.load_state_dict(torch.load(dem_model_file, map_location='cpu'))

Once you have loaded the model, skip to Step 4: Testing the Model.

In [8]:
class DEMdata(torch.utils.data.Dataset):
    def __init__(self, xtrain, ytrain, xtest, ytest, xval, yval, split='train'):
        
        if split == 'train':
            self.x = xtrain
            self.y = ytrain
        if split == 'val':
            self.x = xval
            self.y = yval
        if split == 'test':
            self.x = xtest
            self.y = ytest
            
    def __getitem__(self, index):
        return torch.from_numpy(self.x[index]).type(torch.FloatTensor), torch.from_numpy(self.y[index]).type(torch.FloatTensor)

    def __len__(self):
        return self.x.shape[0]
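A quick smoke test of the dataset pattern with random arrays (shapes only; the real data come from the .npy files above, and the class here is a simplified, self-contained version of DEMdata with a single split):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

# Simplified stand-in for the DEMdata class above: wraps one split only.
class TinyDEMdata(Dataset):
    def __init__(self, x, y):
        self.x, self.y = x, y
    def __getitem__(self, index):
        return (torch.from_numpy(self.x[index]).float(),
                torch.from_numpy(self.y[index]).float())
    def __len__(self):
        return self.x.shape[0]

x = np.random.rand(1, 6, 8, 8)   # one tiny 6-channel "image"
y = np.random.rand(1, 18, 8, 8)  # matching 18-bin "DEM"
loader = DataLoader(TinyDEMdata(x, y), batch_size=1)
img, dem = next(iter(loader))
assert img.shape == (1, 6, 8, 8) and dem.shape == (1, 18, 8, 8)
```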
In [9]:
def train_model(dem_loader, criterion, optimizer, epochs=500):
    model.train()
    train_loss_all_batches = []
    train_loss_epoch = []
    train_val = []
    for k in range(epochs):
        count_ = 0
        avg_loss = 0
        # =================== progress indicator ==============
        if k % ((epochs + 1) // 4) == 0:
            print('[{0}]: {1:.1f}% complete: '.format(k, k / epochs * 100))
        # =====================================================
        for img, dem in dem_loader:
            count_ += 1
            optimizer.zero_grad()
            # =================== forward =====================
            img = cudaize(img)
            dem = cudaize(dem)

            output = model(img) 
            loss = criterion(output, dem)

            loss.backward()
            optimizer.step()
            
            train_loss_all_batches.append(loss.item())
            avg_loss += loss.item()
        # =================== Validation ===================
        dem_data_val = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='val')
        dem_loader_val = DataLoader(dem_data_val, batch_size=1)
        val_loss, dummy, dem_pred_val, dem_in_test_val = valtest_model(dem_loader_val, criterion)
        
        train_loss_epoch.append(avg_loss/count_)
        train_val.append(val_loss)
        
        print('Epoch: ', k, 'trn_loss: ', avg_loss/count_, 'val_loss: ', train_val[k])
            
    torch.save(model.state_dict(), 'DeepEM_CNN_HelioML.pth')
    return train_loss_epoch, train_val

def valtest_model(dem_loader, criterion):

    model.eval()
    
    val_loss = 0
    count = 0
    test_loss = []
    for img, dem in dem_loader:
        count += 1
        # =================== forward =====================
        img = cudaize(img)
        dem = cudaize(dem)
        
        output = model(img)
        loss = criterion(output, dem)
        test_loss.append(loss.item())
        val_loss += loss.item()
        
    return val_loss/count, test_loss, output, dem

We choose the Adam optimiser with a learning rate of 1e-4 and a weight decay of 1e-9. We use the mean squared error (MSE) between the Sparse Inversion DEM map and the DeepEM map as our loss function.

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-9); 
criterion = cudaize(nn.MSELoss())
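nn.MSELoss averages the squared difference over every element of the output cube (all 18 temperature bins and all pixels of the batch). A minimal check of that behaviour with hand-picked values:

```python
import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.zeros(2, 2)
loss = torch.nn.MSELoss()(a, b)
# mean of the squared errors: (1 + 4 + 9 + 16) / 4 = 7.5
assert abs(loss.item() - 7.5) < 1e-6
```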

Using the class and functions defined above, dem_data will return the training data, which is loaded by the DataLoader with batch_size=1 (one 512 x 512 image per batch). For each epoch, train_loss and valdn_loss will be returned by train_model.

In [11]:
dem_data = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='train')
dem_loader = DataLoader(dem_data, batch_size=1)

t0=time.time() #Timing how long it takes to train the model
train_loss, valdn_loss = train_model(dem_loader, criterion, optimizer, epochs=500)
ttime = "Training time = {0} seconds".format(time.time()-t0)
print(ttime)
[0]: 0.0% complete: 
Epoch:  0 trn_loss:  2.688119649887085 val_loss:  2.9241230487823486
Epoch:  1 trn_loss:  2.5150320529937744 val_loss:  2.7400925159454346
Epoch:  2 trn_loss:  2.35073184967041 val_loss:  2.565598487854004
Epoch:  3 trn_loss:  2.1951470375061035 val_loss:  2.400751829147339
Epoch:  4 trn_loss:  2.0484015941619873 val_loss:  2.24615740776062
Epoch:  5 trn_loss:  1.9110569953918457 val_loss:  2.102147102355957
Epoch:  6 trn_loss:  1.7834622859954834 val_loss:  1.9676543474197388
Epoch:  7 trn_loss:  1.6646268367767334 val_loss:  1.8420829772949219
Epoch:  8 trn_loss:  1.5539453029632568 val_loss:  1.7250733375549316
Epoch:  9 trn_loss:  1.4509862661361694 val_loss:  1.6162289381027222
Epoch:  10 trn_loss:  1.3554034233093262 val_loss:  1.5151138305664062
Epoch:  11 trn_loss:  1.266849160194397 val_loss:  1.4214166402816772
Epoch:  12 trn_loss:  1.1849992275238037 val_loss:  1.3348586559295654
Epoch:  13 trn_loss:  1.1095739603042603 val_loss:  1.254988670349121
Epoch:  14 trn_loss:  1.0402250289916992 val_loss:  1.1812376976013184
Epoch:  15 trn_loss:  0.976538360118866 val_loss:  1.1132055521011353
Epoch:  16 trn_loss:  0.9180724620819092 val_loss:  1.050828456878662
Epoch:  17 trn_loss:  0.8646981120109558 val_loss:  0.9940065741539001
Epoch:  18 trn_loss:  0.8163536787033081 val_loss:  0.9426607489585876
Epoch:  19 trn_loss:  0.7729448676109314 val_loss:  0.8966353535652161
Epoch:  20 trn_loss:  0.7342903017997742 val_loss:  0.8555873036384583
Epoch:  21 trn_loss:  0.7000507116317749 val_loss:  0.8194212913513184
Epoch:  22 trn_loss:  0.6700736880302429 val_loss:  0.7879579663276672
Epoch:  23 trn_loss:  0.6442007422447205 val_loss:  0.7609621286392212
Epoch:  24 trn_loss:  0.6222249269485474 val_loss:  0.7381443381309509
Epoch:  25 trn_loss:  0.6038868427276611 val_loss:  0.7191906571388245
Epoch:  26 trn_loss:  0.5888946056365967 val_loss:  0.7037503719329834
Epoch:  27 trn_loss:  0.5769066214561462 val_loss:  0.6914080381393433
Epoch:  28 trn_loss:  0.5675306916236877 val_loss:  0.6817150712013245
Epoch:  29 trn_loss:  0.5603471994400024 val_loss:  0.6741949319839478
Epoch:  30 trn_loss:  0.5549194812774658 val_loss:  0.6683857440948486
Epoch:  31 trn_loss:  0.5508288741111755 val_loss:  0.663855791091919
Epoch:  32 trn_loss:  0.5476891398429871 val_loss:  0.6601954698562622
Epoch:  33 trn_loss:  0.5451509952545166 val_loss:  0.6570515632629395
Epoch:  34 trn_loss:  0.5429093241691589 val_loss:  0.6541380882263184
Epoch:  35 trn_loss:  0.5407118201255798 val_loss:  0.6512284278869629
Epoch:  36 trn_loss:  0.5383663177490234 val_loss:  0.648151695728302
Epoch:  37 trn_loss:  0.5357363820075989 val_loss:  0.6447973847389221
Epoch:  38 trn_loss:  0.5327411890029907 val_loss:  0.6411278247833252
Epoch:  39 trn_loss:  0.5293562412261963 val_loss:  0.6371627449989319
Epoch:  40 trn_loss:  0.5256128311157227 val_loss:  0.6329593658447266
Epoch:  41 trn_loss:  0.5215826034545898 val_loss:  0.6285680532455444
Epoch:  42 trn_loss:  0.5173172354698181 val_loss:  0.6240283846855164
Epoch:  43 trn_loss:  0.5128658413887024 val_loss:  0.6194215416908264
Epoch:  44 trn_loss:  0.5083035826683044 val_loss:  0.6148301362991333
Epoch:  45 trn_loss:  0.5037232041358948 val_loss:  0.6103371977806091
Epoch:  46 trn_loss:  0.4992277920246124 val_loss:  0.6060026288032532
Epoch:  47 trn_loss:  0.49486982822418213 val_loss:  0.6018410921096802
Epoch:  48 trn_loss:  0.4906587600708008 val_loss:  0.5978648662567139
Epoch:  49 trn_loss:  0.48660966753959656 val_loss:  0.5940942168235779
Epoch:  50 trn_loss:  0.482749879360199 val_loss:  0.5905488729476929
Epoch:  51 trn_loss:  0.47909900546073914 val_loss:  0.5872213244438171
Epoch:  52 trn_loss:  0.4756618142127991 val_loss:  0.5840893387794495
Epoch:  53 trn_loss:  0.47243034839630127 val_loss:  0.5811282992362976
Epoch:  54 trn_loss:  0.469387412071228 val_loss:  0.5782973766326904
Epoch:  55 trn_loss:  0.4665057361125946 val_loss:  0.5755579471588135
Epoch:  56 trn_loss:  0.46376025676727295 val_loss:  0.5728747844696045
Epoch:  57 trn_loss:  0.4611271619796753 val_loss:  0.5702178478240967
Epoch:  58 trn_loss:  0.4585789144039154 val_loss:  0.5675610899925232
Epoch:  59 trn_loss:  0.45609772205352783 val_loss:  0.56488436460495
Epoch:  60 trn_loss:  0.4536675214767456 val_loss:  0.5621587038040161
Epoch:  61 trn_loss:  0.45126640796661377 val_loss:  0.5593568682670593
Epoch:  62 trn_loss:  0.44888320565223694 val_loss:  0.5564686059951782
Epoch:  63 trn_loss:  0.4465186893939972 val_loss:  0.5534869432449341
Epoch:  64 trn_loss:  0.4441678822040558 val_loss:  0.550425112247467
Epoch:  65 trn_loss:  0.4418522119522095 val_loss:  0.5472999811172485
Epoch:  66 trn_loss:  0.4395861029624939 val_loss:  0.5441355109214783
Epoch:  67 trn_loss:  0.43736910820007324 val_loss:  0.5409624576568604
Epoch:  68 trn_loss:  0.4352066218852997 val_loss:  0.5378114581108093
Epoch:  69 trn_loss:  0.4331071078777313 val_loss:  0.5347063541412354
Epoch:  70 trn_loss:  0.43108075857162476 val_loss:  0.5316702723503113
Epoch:  71 trn_loss:  0.4291342496871948 val_loss:  0.5287131667137146
Epoch:  72 trn_loss:  0.4272521734237671 val_loss:  0.5258299112319946
Epoch:  73 trn_loss:  0.4254133105278015 val_loss:  0.5230257511138916
Epoch:  74 trn_loss:  0.42360565066337585 val_loss:  0.520313024520874
Epoch:  75 trn_loss:  0.4218240976333618 val_loss:  0.5177026987075806
Epoch:  76 trn_loss:  0.4200654923915863 val_loss:  0.5151981711387634
Epoch:  77 trn_loss:  0.4183226227760315 val_loss:  0.5128055810928345
Epoch:  78 trn_loss:  0.4165956676006317 val_loss:  0.5105316638946533
Epoch:  79 trn_loss:  0.4148896634578705 val_loss:  0.5083771347999573
Epoch:  80 trn_loss:  0.4132143557071686 val_loss:  0.5063313245773315
Epoch:  81 trn_loss:  0.4115768373012543 val_loss:  0.5043664574623108
Epoch:  82 trn_loss:  0.40997231006622314 val_loss:  0.5024513006210327
Epoch:  83 trn_loss:  0.4083913564682007 val_loss:  0.5005679726600647
Epoch:  84 trn_loss:  0.40683066844940186 val_loss:  0.4987040162086487
Epoch:  85 trn_loss:  0.40529051423072815 val_loss:  0.49685144424438477
Epoch:  86 trn_loss:  0.40377312898635864 val_loss:  0.4950086176395416
Epoch:  87 trn_loss:  0.40228208899497986 val_loss:  0.4931758642196655
Epoch:  88 trn_loss:  0.4008176624774933 val_loss:  0.49135148525238037
Epoch:  89 trn_loss:  0.399376779794693 val_loss:  0.4895342290401459
Epoch:  90 trn_loss:  0.3979552090167999 val_loss:  0.48772361874580383
Epoch:  91 trn_loss:  0.39654940366744995 val_loss:  0.4859222173690796
Epoch:  92 trn_loss:  0.39515841007232666 val_loss:  0.4841350018978119
Epoch:  93 trn_loss:  0.3937835693359375 val_loss:  0.48236745595932007
Epoch:  94 trn_loss:  0.3924272358417511 val_loss:  0.4806228280067444
Epoch:  95 trn_loss:  0.3910899758338928 val_loss:  0.4789029359817505
Epoch:  96 trn_loss:  0.3897719383239746 val_loss:  0.477206289768219
Epoch:  97 trn_loss:  0.38847193121910095 val_loss:  0.4755311608314514
Epoch:  98 trn_loss:  0.3871863782405853 val_loss:  0.47387662529945374
Epoch:  99 trn_loss:  0.3859134912490845 val_loss:  0.4722418189048767
Epoch:  100 trn_loss:  0.384651780128479 val_loss:  0.47062763571739197
Epoch:  101 trn_loss:  0.3833991289138794 val_loss:  0.4690333604812622
Epoch:  102 trn_loss:  0.3821508586406708 val_loss:  0.46744483709335327
Epoch:  103 trn_loss:  0.38090336322784424 val_loss:  0.46583521366119385
Epoch:  104 trn_loss:  0.3796510398387909 val_loss:  0.4641898572444916
Epoch:  105 trn_loss:  0.37839025259017944 val_loss:  0.4625108540058136
Epoch:  106 trn_loss:  0.3771190047264099 val_loss:  0.4608005881309509
Epoch:  107 trn_loss:  0.3758341372013092 val_loss:  0.459060937166214
Epoch:  108 trn_loss:  0.37453746795654297 val_loss:  0.4572887420654297
Epoch:  109 trn_loss:  0.3732394576072693 val_loss:  0.4554760754108429
Epoch:  110 trn_loss:  0.3719472289085388 val_loss:  0.45364904403686523
Epoch:  111 trn_loss:  0.3706713020801544 val_loss:  0.451806902885437
Epoch:  112 trn_loss:  0.369411826133728 val_loss:  0.4499753415584564
Epoch:  113 trn_loss:  0.3681654930114746 val_loss:  0.4481814205646515
Epoch:  114 trn_loss:  0.3669349253177643 val_loss:  0.4464665949344635
Epoch:  115 trn_loss:  0.3657364845275879 val_loss:  0.44483789801597595
Epoch:  116 trn_loss:  0.36455634236335754 val_loss:  0.4432632327079773
Epoch:  117 trn_loss:  0.36336278915405273 val_loss:  0.4417327046394348
Epoch:  118 trn_loss:  0.36215659976005554 val_loss:  0.44024181365966797
Epoch:  119 trn_loss:  0.36094820499420166 val_loss:  0.43877968192100525
Epoch:  120 trn_loss:  0.3597407937049866 val_loss:  0.4373362362384796
Epoch:  121 trn_loss:  0.3585361838340759 val_loss:  0.4358990788459778
Epoch:  122 trn_loss:  0.35733479261398315 val_loss:  0.4344557523727417
Epoch:  123 trn_loss:  0.35613512992858887 val_loss:  0.43299537897109985
Epoch:  124 trn_loss:  0.3549352288246155 val_loss:  0.4315100908279419
[125]: 25.0% complete: 
Epoch:  125 trn_loss:  0.35373228788375854 val_loss:  0.42999908328056335
Epoch:  126 trn_loss:  0.3525252938270569 val_loss:  0.42846551537513733
Epoch:  127 trn_loss:  0.35131415724754333 val_loss:  0.4269155263900757
Epoch:  128 trn_loss:  0.35009902715682983 val_loss:  0.4253571927547455
Epoch:  129 trn_loss:  0.3488815128803253 val_loss:  0.423798143863678
Epoch:  130 trn_loss:  0.34766173362731934 val_loss:  0.4222450852394104
Epoch:  131 trn_loss:  0.3464396297931671 val_loss:  0.42070311307907104
Epoch:  132 trn_loss:  0.34521520137786865 val_loss:  0.41917359828948975
Epoch:  133 trn_loss:  0.3439885079860687 val_loss:  0.41765573620796204
Epoch:  134 trn_loss:  0.3427591323852539 val_loss:  0.4161473214626312
Epoch:  135 trn_loss:  0.34152740240097046 val_loss:  0.4146471917629242
Epoch:  136 trn_loss:  0.34029221534729004 val_loss:  0.41315263509750366
Epoch:  137 trn_loss:  0.33905351161956787 val_loss:  0.41166189312934875
Epoch:  138 trn_loss:  0.33781251311302185 val_loss:  0.4101731479167938
Epoch:  139 trn_loss:  0.336570143699646 val_loss:  0.4086863100528717
Epoch:  140 trn_loss:  0.33532702922821045 val_loss:  0.4072016477584839
Epoch:  141 trn_loss:  0.3340844213962555 val_loss:  0.40572142601013184
Epoch:  142 trn_loss:  0.3328438997268677 val_loss:  0.4042476415634155
Epoch:  143 trn_loss:  0.33160579204559326 val_loss:  0.40278181433677673
Epoch:  144 trn_loss:  0.33037033677101135 val_loss:  0.40132516622543335
Epoch:  145 trn_loss:  0.3291375935077667 val_loss:  0.3998756408691406
Epoch:  146 trn_loss:  0.3279068171977997 val_loss:  0.3984292149543762
Epoch:  147 trn_loss:  0.3266780376434326 val_loss:  0.39698001742362976
Epoch:  148 trn_loss:  0.3254525065422058 val_loss:  0.3955231010913849
Epoch:  149 trn_loss:  0.3242305517196655 val_loss:  0.3940570056438446
Epoch:  150 trn_loss:  0.32301077246665955 val_loss:  0.39257463812828064
Epoch:  151 trn_loss:  0.3217824399471283 val_loss:  0.39106231927871704
Epoch:  152 trn_loss:  0.32052335143089294 val_loss:  0.3895217180252075
Epoch:  153 trn_loss:  0.3192353844642639 val_loss:  0.3879607319831848
Epoch:  154 trn_loss:  0.31792783737182617 val_loss:  0.3863987326622009
Epoch:  155 trn_loss:  0.31660303473472595 val_loss:  0.3848625719547272
Epoch:  156 trn_loss:  0.31528016924858093 val_loss:  0.3833721876144409
Epoch:  157 trn_loss:  0.31397897005081177 val_loss:  0.38190972805023193
Epoch:  158 trn_loss:  0.31271278858184814 val_loss:  0.3804458975791931
Epoch:  159 trn_loss:  0.31146588921546936 val_loss:  0.37898194789886475
Epoch:  160 trn_loss:  0.3102269172668457 val_loss:  0.37752583622932434
Epoch:  161 trn_loss:  0.308996319770813 val_loss:  0.3760853409767151
Epoch:  162 trn_loss:  0.3077775239944458 val_loss:  0.37466737627983093
Epoch:  163 trn_loss:  0.3065733313560486 val_loss:  0.3732720613479614
Epoch:  164 trn_loss:  0.3053828477859497 val_loss:  0.3718949258327484
Epoch:  165 trn_loss:  0.304203599691391 val_loss:  0.37052688002586365
Epoch:  166 trn_loss:  0.3030341863632202 val_loss:  0.36915624141693115
Epoch:  167 trn_loss:  0.3018725514411926 val_loss:  0.367771714925766
Epoch:  168 trn_loss:  0.3007168769836426 val_loss:  0.36636537313461304
Epoch:  169 trn_loss:  0.2995656132698059 val_loss:  0.3649352788925171
Epoch:  170 trn_loss:  0.29841864109039307 val_loss:  0.3634849190711975
Epoch:  171 trn_loss:  0.29727718234062195 val_loss:  0.3620240390300751
Epoch:  172 trn_loss:  0.2961430847644806 val_loss:  0.36056458950042725
Epoch:  173 trn_loss:  0.2950180172920227 val_loss:  0.35911887884140015
Epoch:  174 trn_loss:  0.293902724981308 val_loss:  0.35769572854042053
Epoch:  175 trn_loss:  0.292797327041626 val_loss:  0.3563007116317749
Epoch:  176 trn_loss:  0.29170113801956177 val_loss:  0.35493606328964233
Epoch:  177 trn_loss:  0.29061368107795715 val_loss:  0.3536009192466736
Epoch:  178 trn_loss:  0.2895352840423584 val_loss:  0.35229235887527466
Epoch:  179 trn_loss:  0.2884659171104431 val_loss:  0.35100576281547546
Epoch:  180 trn_loss:  0.2874058783054352 val_loss:  0.34973642230033875
Epoch:  181 trn_loss:  0.286355584859848 val_loss:  0.34847965836524963
Epoch:  182 trn_loss:  0.2853151559829712 val_loss:  0.34723153710365295
Epoch:  183 trn_loss:  0.284284770488739 val_loss:  0.3459901213645935
Epoch:  184 trn_loss:  0.28326448798179626 val_loss:  0.34475380182266235
Epoch:  185 trn_loss:  0.2822543680667877 val_loss:  0.3435213565826416
Epoch:  186 trn_loss:  0.2812545895576477 val_loss:  0.3422924280166626
Epoch:  187 trn_loss:  0.2802650034427643 val_loss:  0.3410668969154358
Epoch:  188 trn_loss:  0.27928587794303894 val_loss:  0.33984512090682983
Epoch:  189 trn_loss:  0.2783172130584717 val_loss:  0.3386283218860626
Epoch:  190 trn_loss:  0.2773592472076416 val_loss:  0.3374187648296356
Epoch:  191 trn_loss:  0.27641215920448303 val_loss:  0.33621901273727417
Epoch:  192 trn_loss:  0.27547597885131836 val_loss:  0.33503231406211853
Epoch:  193 trn_loss:  0.2745507061481476 val_loss:  0.3338610529899597
Epoch:  194 trn_loss:  0.27363622188568115 val_loss:  0.3327072858810425
Epoch:  195 trn_loss:  0.27273234724998474 val_loss:  0.3315722942352295
Epoch:  196 trn_loss:  0.2718389332294464 val_loss:  0.33045694231987
Epoch:  197 trn_loss:  0.2709559202194214 val_loss:  0.32936158776283264
Epoch:  198 trn_loss:  0.2700830399990082 val_loss:  0.3282855153083801
Epoch:  199 trn_loss:  0.26922017335891724 val_loss:  0.3272271156311035
Epoch:  200 trn_loss:  0.2683671712875366 val_loss:  0.32618460059165955
Epoch:  201 trn_loss:  0.2675240933895111 val_loss:  0.3251557946205139
Epoch:  202 trn_loss:  0.26669085025787354 val_loss:  0.3241388499736786
Epoch:  203 trn_loss:  0.26586735248565674 val_loss:  0.3231317102909088
Epoch:  204 trn_loss:  0.2650529146194458 val_loss:  0.32213228940963745
Epoch:  205 trn_loss:  0.2642463445663452 val_loss:  0.321138471364975
Epoch:  206 trn_loss:  0.2634473443031311 val_loss:  0.320149689912796
Epoch:  207 trn_loss:  0.2626544237136841 val_loss:  0.31916573643684387
Epoch:  208 trn_loss:  0.2618657350540161 val_loss:  0.31818705797195435
Epoch:  209 trn_loss:  0.26107966899871826 val_loss:  0.31722477078437805
Epoch:  210 trn_loss:  0.26030483841896057 val_loss:  0.3162919282913208
Epoch:  211 trn_loss:  0.25955092906951904 val_loss:  0.31537410616874695
Epoch:  212 trn_loss:  0.25880908966064453 val_loss:  0.3144550323486328
Epoch:  213 trn_loss:  0.2580699324607849 val_loss:  0.3135058581829071
Epoch:  214 trn_loss:  0.257336288690567 val_loss:  0.3125093877315521
Epoch:  215 trn_loss:  0.25659897923469543 val_loss:  0.3114919364452362
Epoch:  216 trn_loss:  0.2558664083480835 val_loss:  0.31049755215644836
Epoch:  217 trn_loss:  0.25514885783195496 val_loss:  0.30955731868743896
Epoch:  218 trn_loss:  0.2544456422328949 val_loss:  0.3086819052696228
Epoch:  219 trn_loss:  0.25375044345855713 val_loss:  0.30786967277526855
Epoch:  220 trn_loss:  0.25306236743927 val_loss:  0.3071036636829376
Epoch:  221 trn_loss:  0.25238242745399475 val_loss:  0.3063604533672333
Epoch:  222 trn_loss:  0.25170987844467163 val_loss:  0.30561569333076477
Epoch:  223 trn_loss:  0.25104430317878723 val_loss:  0.30485162138938904
Epoch:  224 trn_loss:  0.25038576126098633 val_loss:  0.30405986309051514
Epoch:  225 trn_loss:  0.24973466992378235 val_loss:  0.3032422363758087
Epoch:  226 trn_loss:  0.24909022450447083 val_loss:  0.3024085760116577
Epoch:  227 trn_loss:  0.24845078587532043 val_loss:  0.30157291889190674
Epoch:  228 trn_loss:  0.24781635403633118 val_loss:  0.30074918270111084
Epoch:  229 trn_loss:  0.24718907475471497 val_loss:  0.2999477684497833
Epoch:  230 trn_loss:  0.2465706467628479 val_loss:  0.29917412996292114
Epoch:  231 trn_loss:  0.24596036970615387 val_loss:  0.2984277307987213
Epoch:  232 trn_loss:  0.24535663425922394 val_loss:  0.2977038025856018
Epoch:  233 trn_loss:  0.24475733935832977 val_loss:  0.296994686126709
Epoch:  234 trn_loss:  0.24416306614875793 val_loss:  0.2962947487831116
Epoch:  235 trn_loss:  0.24357406795024872 val_loss:  0.295600563287735
Epoch:  236 trn_loss:  0.2429903894662857 val_loss:  0.2949105501174927
Epoch:  237 trn_loss:  0.24241246283054352 val_loss:  0.2942263185977936
Epoch:  238 trn_loss:  0.24184152483940125 val_loss:  0.29354920983314514
Epoch:  239 trn_loss:  0.24127882719039917 val_loss:  0.2928784191608429
Epoch:  240 trn_loss:  0.24072442948818207 val_loss:  0.2922115623950958
Epoch:  241 trn_loss:  0.2401776909828186 val_loss:  0.2915470004081726
Epoch:  242 trn_loss:  0.23963740468025208 val_loss:  0.2908848524093628
Epoch:  243 trn_loss:  0.23910215497016907 val_loss:  0.2902263402938843
Epoch:  244 trn_loss:  0.23857004940509796 val_loss:  0.28957512974739075
...
[250]: 50.0% complete: 
Epoch:  250 trn_loss:  0.23544162511825562 val_loss:  0.28570863604545593
...
[375]: 75.0% complete: 
Epoch:  375 trn_loss:  0.19050052762031555 val_loss:  0.2395990639925003
...
Epoch:  498 trn_loss:  0.16509193181991577 val_loss:  0.21154600381851196
Epoch:  499 trn_loss:  0.16491292417049408 val_loss:  0.2113374024629593
Training time = 45.15088415145874 seconds

Plotting: MSE Loss for Training and Validation

To assess how well the model has trained, Figure 2 plots the MSE loss for training (blue) and validation (orange) as a function of epoch.

In [12]:
plt.plot(np.arange(len(train_loss)), train_loss, color="blue", label="Training")
plt.plot(np.arange(len(valdn_loss)), valdn_loss, color="orange", label="Validation")
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.legend()
plt.show()

Figure 2: Training and Validation MSE loss (blue, orange) as a function of Epoch.


Step 4: Testing the Model

Now that the model has been trained, testing it is a computationally cheap procedure. As before, we select the data using DEMdata and load it with DataLoader. Using valtest_model, the DeepEM map is created (${\texttt{output = model(img)}}$), and the MSE loss is calculated as during training.
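For reference, an evaluation loop of the shape that valtest_model implements typically looks like the following minimal sketch (the real function is defined earlier in the notebook; the names and return signature here are illustrative assumptions, not the notebook's code):

```python
import torch

def evaluate(model, loader, criterion, device="cpu"):
    """Minimal evaluation-loop sketch (illustrative; not the notebook's
    valtest_model). Runs the model without gradients and returns the
    mean loss along with the stacked predictions and targets."""
    model.eval()
    total_loss, n_batches = 0.0, 0
    preds, targets = [], []
    with torch.no_grad():
        for img, target in loader:
            img, target = img.to(device), target.to(device)
            output = model(img)  # the DeepEM map for this batch
            total_loss += criterion(output, target).item()
            n_batches += 1
            preds.append(output)
            targets.append(target)
    return total_loss / max(n_batches, 1), torch.cat(preds), torch.cat(targets)
```

Because no gradients or optimiser steps are involved, a pass of this kind over the test set is far cheaper than a training epoch.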

In [13]:
dem_data_test = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='test')
dem_loader = DataLoader(dem_data_test, batch_size=1)
In [14]:
t0=time.time() #Timing how long it takes to predict the DEMs
dummy, test_loss, dem_pred, dem_in_test = valtest_model(dem_loader, criterion)
performance = "Number of DEM solutions per second = {0}".format((y_test.shape[2]*y_test.shape[3])/(time.time()-t0))
In [15]:
print(performance)
Number of DEM solutions per second = 8532871.017073322

Plotting: AIA, Basis Pursuit, DeepEM

With the DeepEM map calculated, we can now compare the solutions obtained by Basis Pursuit and DeepEM. Figure 3 is similar to Figure 1, with an additional row for the DeepEM solutions: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top); the corresponding DEM bins from Basis Pursuit, chosen at the peak sensitivity of each SDO/AIA channel (middle); and the DeepEM solutions for the same bins (bottom). DeepEM provides solutions similar to Basis Pursuit but, importantly, returns a DEM solution for every pixel.

In [16]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_in_test[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_in_test[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_in_test[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_pred[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_pred[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_pred[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)

Figure 3: Left to Right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top), with the corresponding DEM bins from Basis Pursuit (chosen at the peak sensitivity of each of the SDO/AIA channels) shown below (middle). The bottom row shows the DeepEM solutions that correspond to the same bins as the Basis Pursuit solutions. DeepEM provides solutions that are similar to Basis Pursuit but, importantly, provides DEM solutions for every pixel.


Furthermore, since we have the original Basis Pursuit DEM solutions ("the ground truth"), we can compare the average DEM from Basis Pursuit to the average DEM from DeepEM; they should be similar. Figure 4 shows the average Basis Pursuit DEM (black curve) and the average DeepEM solution (dashed light-blue curve).

In [17]:
def PlotTotalEM(em_unscaled, em_pred_unscaled, lgtaxis, status):
    mask = np.zeros([status.shape[0],status.shape[1]])
    mask[np.where(status == 0.0)] = 1.0
    nmask = np.sum(mask)
    
    EM_tru_sum = np.zeros([lgtaxis.size])
    EM_inv_sum = np.zeros([lgtaxis.size])
    
    for i in range(lgtaxis.size):
        EM_tru_sum[i] = np.sum(em_unscaled[0,i,:,:]*mask)/nmask
        EM_inv_sum[i] = np.sum(em_pred_unscaled[0,i,:,:]*mask)/nmask
        
    fig = plt.figure()
    plt.plot(lgtaxis,EM_tru_sum, linewidth=3, color="black")
    plt.plot(lgtaxis,EM_inv_sum, linewidth=3, color="lightblue", linestyle='--')
    plt.tick_params(axis='both', which='major')#, labelsize=16)
    plt.tick_params(axis='both', which='minor')#, labelsize=16)
    
    dlogT = lgtaxis[1]-lgtaxis[0]
    #plt.bar(lgtaxis-0.5*dlogT, EM_inv_sum, dlogT, linewidth=2, color='lightblue')
    #plt.bar(lgtaxis, EM_inv_sum, dlogT, linewidth=2, color='lightblue')

    plt.xlim(lgtaxis[0]-0.5*dlogT, lgtaxis.max()+0.5*dlogT)
    plt.xticks(np.arange(np.min(lgtaxis), np.max(lgtaxis),2*dlogT))
    plt.ylim(1e24,1e27)
    plt.yscale('log')
    plt.xlabel('log$_{10}$T [K]')
    plt.ylabel('Mean Emission Measure [cm$^{-5}$]')
    plt.title('Basis Pursuit (solid) vs. DeepEM (dashed)')
    
    plt.show()
    return EM_inv_sum, EM_tru_sum
In [18]:
em_unscaled = em_unscale(dem_in_test.detach().cpu().numpy())
em_pred_unscaled = em_unscale(dem_pred.detach().cpu().numpy())
status = np.zeros([512,512]) #Setting statuses to zero here, but could be provided
                   
EMinv, EMTru = PlotTotalEM(em_unscaled,em_pred_unscaled,lgtaxis,status)

Figure 4: Average Basis Pursuit DEM (black line) against the average DeepEM solution (dashed line). This simple implementation of DeepEM provides, on average, DEMs that are similar to Basis Pursuit (Cheung et al 2015).


Step 5: Synthesize SDO/AIA Observations

Finally, it is also of interest to reconstruct the SDO/AIA observations from both the Basis Pursuit, and DeepEM solutions.

We can pose the reconstruction of the SDO/AIA observations from the DEM as a 1x1 2D convolution: we define the convolution weights as the temperature response functions of each channel and set the biases to zero. Convolving the unscaled DEM at each pixel (18 temperature bins) with the 6 filters (one per SDO/AIA response function) recovers the SDO/AIA intensities.
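The equivalence between this 1x1 convolution and the per-pixel matrix multiplication $\vec{g} = {\bf K}\vec{\xi}$ can be checked directly. A minimal sketch, using random stand-ins for the response functions and the DEM (only the shapes, 18 temperature bins and 6 channels, come from the text):

```python
import torch
import torch.nn as nn

n_T, n_chan, H, W = 18, 6, 4, 4   # 18 temperature bins, 6 AIA channels
K = torch.randn(n_chan, n_T)      # stand-in for the response functions
dem = torch.rand(1, n_T, H, W)    # stand-in for an unscaled DEM map

# 1x1 convolution with weights K and zero biases, as in the notebook
conv = nn.Conv2d(n_T, n_chan, kernel_size=1)
with torch.no_grad():
    conv.weight.copy_(K.view(n_chan, n_T, 1, 1))
    conv.bias.zero_()
g_conv = conv(dem)

# The same operation written as an explicit per-pixel matrix multiply
g_matmul = torch.einsum('ct,bthw->bchw', K, dem)
assert torch.allclose(g_conv, g_matmul, atol=1e-5)
```

A 1x1 kernel mixes channels without mixing pixels, which is exactly the structure of the forward problem: each pixel's DEM is mapped independently through the same response matrix.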

In [19]:
# We first load the AIA response functions:
cl = np.load('./DeepEM_Data/chianti_lines_AIA.npy')
In [20]:
# Use Conv2d to map the DEM at every pixel (18x1x1) through the 6 response
# functions, returning the observed flux in each channel (6x1x1)
dem2aia = cudaize(nn.Conv2d(18, 6, kernel_size=1))

chianti_lines_2 = cudaize(torch.zeros(6,18,1,1))
biases = cudaize(torch.zeros(6))

# set the weights to each of the SDO/AIA response functions and biases to zero
for i, p in enumerate(dem2aia.parameters()):
    if i == 0:
        p.data = Variable(cudaize(torch.from_numpy(cl).type(torch.FloatTensor)))
    else:
        p.data = biases 
In [21]:
AIA_out = img_scale(dem2aia(Variable(em_unscale(dem_in_test))).detach().cpu().numpy())
AIA_out_DeepEM = img_scale(dem2aia(Variable(em_unscale(dem_pred))).detach().cpu().numpy())

Plotting SDO/AIA Observations and Synthetic Observations

In [22]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[0].imshow(AIA_out[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'Basis Pursuit Synthesized 171 $\AA$', color="white", size='large')
ax[1].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 171 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[0].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[1].imshow(AIA_out[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'Basis Pursuit Synthesized 211 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 211 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[0].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')
ax[1].imshow(AIA_out[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'Basis Pursuit Synthesized 94 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    

Figure 5: Top to Bottom: SDO/AIA images in 171 Å, 211 Å, and 94 Å (left) with the corresponding synthesised observations from Basis Pursuit (middle) and DeepEM (right). DeepEM provides synthetic observations that are similar to Basis Pursuit, with the addition of solutions where the Basis Pursuit solution was zero.


Discussion

This chapter has provided a simple example of how a 1x1 2D convolutional neural network can be used to reduce the computational cost of DEM inversion. Future improvements to DeepEM could come in a few ways:

First, by using both the original and synthesised data from the DEM, the ability of the DEM to recover the original or supplementary data (such as spectroscopic EUV data) can be used as an additional term in the loss function.
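As a hedged sketch, such a combined objective might look like the following, reusing a frozen dem2aia-style convolution to map predicted DEMs back to channel intensities (the function name and weighting are illustrative assumptions, not the notebook's implementation):

```python
import torch
import torch.nn.functional as F

def combined_loss(dem_pred, dem_true, aia_obs, dem2aia, weight=0.1):
    """Illustrative combined objective: DEM fidelity plus how well the
    predicted DEM reproduces the observed AIA intensities through a
    (frozen) DEM-to-AIA convolution. `weight` is an assumed hyperparameter."""
    loss_dem = F.mse_loss(dem_pred, dem_true)
    loss_aia = F.mse_loss(dem2aia(dem_pred), aia_obs)
    return loss_dem + weight * loss_aia
```

The second term penalises DEMs that fail to reproduce the observations, which could help constrain the solution in regions where the Basis Pursuit training target is zero.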

Second, this implementation of DeepEM has been trained on a single set of observations. While there are 512$^{2}$ DEMs in one set of observations, it would be advisable to train the model on further images of the Sun in various states of activity.


Appendix A: What has the CNN learned about our training set?

If we treat our training set as the test set, we can see how much the CNN has learned about the training data.

In [23]:
X_test = X_train 
y_test = y_train
In [24]:
dem_data_test = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='test')
dem_loader = DataLoader(dem_data_test, batch_size=1)

dummy, test_loss, dem_pred_trn, dem_in_test_trn = valtest_model(dem_loader, criterion)
In [25]:
AIA_out = img_scale(dem2aia(Variable(em_unscale(dem_in_test_trn))).detach().cpu().numpy())
AIA_out_DeepEM = img_scale(dem2aia(Variable(em_unscale(dem_pred_trn))).detach().cpu().numpy())
In [26]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_in_test_trn[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_in_test_trn[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_in_test_trn[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_pred_trn[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_pred_trn[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_pred_trn[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)

This shows that, even on the training data, the model has not learned the exact mapping from specific SDO/AIA observations to DEMs; there is sufficient generalisation that the zero DEMs are not memorised by the model.

Finally, we can synthesise the SDO/AIA observations, as previously.

In [27]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[0].imshow(AIA_out[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'Basis Pursuit Synthesized 171 $\AA$', color="white", size='large')
ax[1].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 171 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[0].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[1].imshow(AIA_out[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'Basis Pursuit Synthesized 211 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 211 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[0].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')
ax[1].imshow(AIA_out[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'Basis Pursuit Synthesized 94 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)